{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "39af8c33",
   "metadata": {},
   "source": [
    "# Wide & Deep Recommendation for large scale data - Feature Engineering"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ea4cc7",
   "metadata": {},
   "source": [
    "This demo uses the [Twitter Recsys Challenge 2021 dataset](http://www.recsyschallenge.com/2021/) and the [Wide & Deep Model](https://arxiv.org/abs/1606.07792). The dataset includes 46 million users and 340 million tweets (items) and each record contains the tweet along with engagement features, user features, and tweet features.\n",
    "\n",
    "At the very beginning, let's have a high-level overview of general recommendation systems. The diagram below demonstrates the common components of a recommendation system, which typically consists of three stages:\n",
    "- Offline: Perform feature engineering on the raw data and use the preprocessed data to train embeddings and deep learning models.\n",
    "\n",
    "- Nearline: Retrieve user/item profiles and keep them in the Key-Value store. Make updates to the profiles and fine-tune the deep learning model from time to time.\n",
    "\n",
    "- Online: Trigger the recommendation process whenever a user request comes. Recall service generates candidates from millions of items based on embedding similarity and ranking services uses the trained deep learning model to re-rank the candidates for the final recommendation results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54ebeebf",
   "metadata": {},
   "source": [
    "<img src=\"figures/overview-recsys.png\" alt=\"overview-recsys\" width=\"750\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5087be1c",
   "metadata": {},
   "source": [
    "This notebook demonstrates some common data preprocessing and feature engineering steps for Wide & Deep Learning on the Twitter Recsys Challenge 2021 dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad9effd3",
   "metadata": {},
   "source": [
    "First of all, we import the necessary packages in BigDL for cluster initialization and built-in recommendation operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "313f2a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.orca import init_orca_context, stop_orca_context, OrcaContext\n",
    "from bigdl.friesian.feature import FeatureTable, StringIndex"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ce186eb",
   "metadata": {},
   "source": [
    "Initialize the environment on the YARN cluster. You simply need to prepare the Python environment on the driver node with [Anaconda](https://www.anaconda.com/products/individual) and BigDL will automatically distribute and prepare the environment for you across the cluster.\n",
    "Besides, you can specify the allocated resources for this application during the initialization, including the number of nodes, cores and the amount of memory to use, etc. BigDL provides detailed guidance to be easily deployed on [Hadoop/YARN](https://bigdl.readthedocs.io/en/latest/doc/UserGuide/hadoop.html) or [K8S](https://bigdl.readthedocs.io/en/latest/doc/UserGuide/k8s.html) clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4880fe53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing orca context\n",
      "Current pyspark location is : /root/anaconda3/envs/bigdl/lib/python3.7/site-packages/pyspark/__init__.py\n",
      "Initializing SparkContext for yarn-client mode\n",
      "Start to pack current python env\n",
      "Collecting packages...\n",
      "Packing environment at '/root/anaconda3/envs/bigdl' to '/tmp/tmp0d6y97wf/python_env.tar.gz'\n",
      "[########################################] | 100% Completed | 15.0s\n",
      "Packing has been completed: /tmp/tmp0d6y97wf/python_env.tar.gz\n",
      "pyspark_submit_args is: --master yarn --deploy-mode client --archives /tmp/tmp0d6y97wf/python_env.tar.gz#python_env --driver-cores 4 --driver-memory 36g --num-executors 6 --executor-cores 36 --executor-memory 96g --driver-class-path /root/anaconda3/envs/bigdl/lib/python3.7/site-packages/bigdl/share/dllib/lib/bigdl-dllib-spark_2.4.6-2.0.0-jar-with-dependencies.jar:/root/anaconda3/envs/bigdl/lib/python3.7/site-packages/bigdl/share/orca/lib/bigdl-orca-spark_2.4.6-2.0.0-jar-with-dependencies.jar:/root/anaconda3/envs/bigdl/lib/python3.7/site-packages/bigdl/share/friesian/lib/bigdl-friesian-spark_2.4.6-2.0.0-jar-with-dependencies.jar pyspark-shell\n",
      "2022-03-21 13:22:32 WARN  NativeCodeLoader:62 - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "2022-03-21 13:22:33 WARN  Client:66 - Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-03-21 13:23:42,002 Thread-5 WARN The bufferSize is set to 4000 but bufferedIo is false: false\n",
      "2022-03-21 13:23:42,005 Thread-5 WARN The bufferSize is set to 4000 but bufferedIo is false: false\n",
      "2022-03-21 13:23:42,006 Thread-5 WARN The bufferSize is set to 4000 but bufferedIo is false: false\n",
      "2022-03-21 13:23:42,007 Thread-5 WARN The bufferSize is set to 4000 but bufferedIo is false: false\n",
      "22-03-21 13:23:42 [Thread-5] INFO  Engine$:121 - Auto detect executor number and executor cores number\n",
      "22-03-21 13:23:42 [Thread-5] INFO  Engine$:123 - Executor number is 6 and executor cores number is 36\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "User settings:\n",
      "\n",
      "   KMP_AFFINITY=granularity=fine,compact,1,0\n",
      "   KMP_BLOCKTIME=0\n",
      "   KMP_SETTINGS=1\n",
      "   OMP_NUM_THREADS=1\n",
      "\n",
      "Effective settings:\n",
      "\n",
      "   KMP_ABORT_DELAY=0\n",
      "   KMP_ADAPTIVE_LOCK_PROPS='1,1024'\n",
      "   KMP_ALIGN_ALLOC=64\n",
      "   KMP_ALL_THREADPRIVATE=224\n",
      "   KMP_ATOMIC_MODE=2\n",
      "   KMP_BLOCKTIME=0\n",
      "   KMP_CPUINFO_FILE: value is not defined\n",
      "   KMP_DETERMINISTIC_REDUCTION=false\n",
      "   KMP_DEVICE_THREAD_LIMIT=2147483647\n",
      "   KMP_DISP_HAND_THREAD=false\n",
      "   KMP_DISP_NUM_BUFFERS=7\n",
      "   KMP_DUPLICATE_LIB_OK=false\n",
      "   KMP_FORCE_REDUCTION: value is not defined\n",
      "   KMP_FOREIGN_THREADS_THREADPRIVATE=true\n",
      "   KMP_FORKJOIN_BARRIER='2,2'\n",
      "   KMP_FORKJOIN_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_FORKJOIN_FRAMES=true\n",
      "   KMP_FORKJOIN_FRAMES_MODE=3\n",
      "   KMP_GTID_MODE=3\n",
      "   KMP_HANDLE_SIGNALS=false\n",
      "   KMP_HOT_TEAMS_MAX_LEVEL=1\n",
      "   KMP_HOT_TEAMS_MODE=0\n",
      "   KMP_INIT_AT_FORK=true\n",
      "   KMP_ITT_PREPARE_DELAY=0\n",
      "   KMP_LIBRARY=throughput\n",
      "   KMP_LOCK_KIND=queuing\n",
      "   KMP_MALLOC_POOL_INCR=1M\n",
      "   KMP_MWAIT_HINTS=0\n",
      "   KMP_NUM_LOCKS_IN_BLOCK=1\n",
      "   KMP_PLAIN_BARRIER='2,2'\n",
      "   KMP_PLAIN_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_REDUCTION_BARRIER='1,1'\n",
      "   KMP_REDUCTION_BARRIER_PATTERN='hyper,hyper'\n",
      "   KMP_SCHEDULE='static,balanced;guided,iterative'\n",
      "   KMP_SETTINGS=true\n",
      "   KMP_SPIN_BACKOFF_PARAMS='4096,100'\n",
      "   KMP_STACKOFFSET=64\n",
      "   KMP_STACKPAD=0\n",
      "   KMP_STACKSIZE=8M\n",
      "   KMP_STORAGE_MAP=false\n",
      "   KMP_TASKING=2\n",
      "   KMP_TASKLOOP_MIN_TASKS=0\n",
      "   KMP_TASK_STEALING_CONSTRAINT=1\n",
      "   KMP_TEAMS_THREAD_LIMIT=56\n",
      "   KMP_TOPOLOGY_METHOD=all\n",
      "   KMP_USER_LEVEL_MWAIT=false\n",
      "   KMP_USE_YIELD=1\n",
      "   KMP_VERSION=false\n",
      "   KMP_WARNINGS=true\n",
      "   OMP_AFFINITY_FORMAT='OMP: pid %P tid %i thread %n bound to OS proc set {%A}'\n",
      "   OMP_ALLOCATOR=omp_default_mem_alloc\n",
      "   OMP_CANCELLATION=false\n",
      "   OMP_DEBUG=disabled\n",
      "   OMP_DEFAULT_DEVICE=0\n",
      "   OMP_DISPLAY_AFFINITY=false\n",
      "   OMP_DISPLAY_ENV=false\n",
      "   OMP_DYNAMIC=false\n",
      "   OMP_MAX_ACTIVE_LEVELS=2147483647\n",
      "   OMP_MAX_TASK_PRIORITY=0\n",
      "   OMP_NESTED=false\n",
      "   OMP_NUM_THREADS='1'\n",
      "   OMP_PLACES: value is not defined\n",
      "   OMP_PROC_BIND='intel'\n",
      "   OMP_SCHEDULE='static'\n",
      "   OMP_STACKSIZE=8M\n",
      "   OMP_TARGET_OFFLOAD=DEFAULT\n",
      "   OMP_THREAD_LIMIT=2147483647\n",
      "   OMP_TOOL=enabled\n",
      "   OMP_TOOL_LIBRARIES: value is not defined\n",
      "   OMP_WAIT_POLICY=PASSIVE\n",
      "   KMP_AFFINITY='noverbose,warnings,respect,granularity=fine,compact,1,0'\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22-03-21 13:23:42 [Thread-5] INFO  ThreadPool$:95 - Set mkl threads to 1 on thread 28\n",
      "2022-03-21 13:23:42 WARN  SparkContext:66 - Using an existing SparkContext; some configuration may not take effect.\n",
      "22-03-21 13:23:42 [Thread-5] INFO  Engine$:446 - Find existing spark context. Checking the spark conf...\n",
      "cls.getname: com.intel.analytics.bigdl.dllib.utils.python.api.Sample\n",
      "BigDLBasePickler registering: bigdl.dllib.utils.common  Sample\n",
      "cls.getname: com.intel.analytics.bigdl.dllib.utils.python.api.EvaluatedResult\n",
      "BigDLBasePickler registering: bigdl.dllib.utils.common  EvaluatedResult\n",
      "cls.getname: com.intel.analytics.bigdl.dllib.utils.python.api.JTensor\n",
      "BigDLBasePickler registering: bigdl.dllib.utils.common  JTensor\n",
      "cls.getname: com.intel.analytics.bigdl.dllib.utils.python.api.JActivity\n",
      "BigDLBasePickler registering: bigdl.dllib.utils.common  JActivity\n"
     ]
    }
   ],
   "source": [
    "# To display terminal's stdout and stderr in the Jupyter notebook.\n",
    "OrcaContext.log_output = True\n",
    "\n",
    "cluster_mode = \"yarn\"\n",
    "\n",
    "executor_cores = 36\n",
    "num_executor = 6\n",
    "executor_memory = \"96g\"\n",
    "driver_cores = 4\n",
    "driver_memory = \"36g\"\n",
    "conf = {\"spark.network.timeout\": \"10000000\",\n",
    "        \"spark.sql.broadcastTimeout\": \"7200\",\n",
    "        \"spark.sql.shuffle.partitions\": \"2000\",\n",
    "        \"spark.locality.wait\": \"0s\",\n",
    "        \"spark.sql.crossJoin.enabled\": \"true\",\n",
    "        \"spark.task.cpus\": \"1\",\n",
    "        \"spark.executor.heartbeatInterval\": \"200s\",\n",
    "        \"spark.driver.maxResultSize\": \"40G\",\n",
    "        \"spark.eventLog.enabled\": \"true\",\n",
    "        \"spark.app.name\": \"recsys-demo-preprocess\",\n",
    "        \"spark.debug.maxToStringFields\": \"100\"}\n",
    "if cluster_mode == \"local\":  # For local machine\n",
    "    sc = init_orca_context(cluster_mode=\"local\",\n",
    "                           cores=executor_cores, memory=executor_memory)\n",
    "elif cluster_mode == \"yarn\":  # For Hadoop/YARN cluster\n",
    "    sc = init_orca_context(cluster_mode=\"yarn\", cores=executor_cores,\n",
    "                           num_nodes=num_executor, memory=executor_memory,\n",
    "                           driver_cores=driver_cores, driver_memory=driver_memory,\n",
    "                           conf=conf)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "889475d8",
   "metadata": {},
   "source": [
    "Load raw train and validation data as FeatureTables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f1c82b80",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of train records: 747694282\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "[Stage 4:==================================================>   (200 + 14) / 214]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of validation records: 14461760\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "train_path = \"/path/to/train/data\"\n",
    "valid_path = \"/path/to/valid/data\"\n",
    "\n",
    "train_tbl = FeatureTable.read_parquet(train_path)\n",
    "valid_tbl = FeatureTable.read_parquet(valid_path)\n",
    "\n",
    "print(\"Total number of train records: {}\".format(train_tbl.size()))\n",
    "print(\"Total number of validation records: {}\".format(valid_tbl.size()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53e80641",
   "metadata": {},
   "source": [
    "Common features for recommendation include boolean features, categorical features (mostly string) and continuous (numeric) features. Several typical features of these types in this dataset are listed below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6bbb49c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "bool_cols = [\n",
    "    'engaged_with_user_is_verified',\n",
    "    'enaging_user_is_verified'\n",
    "]\n",
    "\n",
    "count_cols = [\n",
    "    'engaged_with_user_follower_count',\n",
    "    'engaged_with_user_following_count',\n",
    "    'enaging_user_follower_count',\n",
    "    'enaging_user_following_count'\n",
    "]\n",
    "\n",
    "cat_cols = [\n",
    "    'present_media',\n",
    "    'tweet_type',\n",
    "    'language'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f2eff1d",
   "metadata": {},
   "source": [
    "<img src=\"figures/feature.png\" alt=\"feature\" width=\"750\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a610c1e",
   "metadata": {},
   "source": [
    "Now we let's start the data preprocessing for both the train and validation dataset! With the built-in high-level preprocessing operations in FeatureTable, you can achieve this using only several lines of code. :) \n",
    "\n",
    "- Fill null with default values.\n",
    "- For boolean features, we simply cast them to integers (either 0 or 1).\n",
    "- For categorical features:\n",
    "    - If all the categories are already known (for example in this dataset, there will be only 13 present_media options and 3 tweet_type options shown below), to save computation, you can directly assign each category with an id by yourself and use the string_index map to encode the features to the corresponding ids.\n",
    "    - Alternatively, the StringIndex for the categories that appear in the dataset will be generated first and then the original features will be encoded to the corresponding ids.\n",
    "- For continuous features:\n",
    "    - If the values vary along a large range, for example in this dataset, the following/follower counts may vary from 0 to million, we can put them into discrete bins and assign each bin with an index. In this case, the resulting features will actually be converted to categorical features. In the following example, 8 bins are used for following/follower counts (range: <1, 1-1e2, 1e2-1e3, 1e3-1e4, 1e4-1e5, 1e5-1e6, 1e6-1e7, >1e7) and they will be represented by 0 to 7 respectively.\n",
    "    - If the values don't vary a lot, min max scaling is a common approach to rescale the continuous values to the range [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0e3ce568",
   "metadata": {},
   "outputs": [],
   "source": [
    "media_map = {\n",
    "    '': 0,\n",
    "    'GIF': 1,\n",
    "    'GIF_GIF': 2,\n",
    "    'GIF_Photo': 3,\n",
    "    'GIF_Video': 4,\n",
    "    'Photo': 5,\n",
    "    'Photo_GIF': 6,\n",
    "    'Photo_Photo': 7,\n",
    "    'Photo_Video': 8,\n",
    "    'Video': 9,\n",
    "    'Video_GIF': 10,\n",
    "    'Video_Photo': 11,\n",
    "    'Video_Video': 12\n",
    "}\n",
    "\n",
    "type_map = {\n",
    "    'Quote': 0,\n",
    "    'Retweet': 1,\n",
    "    'TopLevel': 2,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ecdd941f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(tbl):\n",
    "    tbl = tbl.fillna(\"\", \"present_media\")\n",
    "    tbl = tbl.cast(bool_cols + count_cols, \"int\")  # cast bool and long to int\n",
    "    tbl = tbl.cut_bins(columns=count_cols,\n",
    "                       bins=[1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7],\n",
    "                       out_cols=count_cols)\n",
    "    if \"present_media\" in cat_cols:\n",
    "        process_media = lambda x: '_'.join(x.split('\\t')[:2])\n",
    "        tbl = tbl.apply(\"present_media\", \"present_media\", process_media, \"string\")\n",
    "        tbl = tbl.encode_string(\"present_media\", media_map)\n",
    "    if \"tweet_type\" in cat_cols:\n",
    "        tbl = tbl.encode_string(\"tweet_type\", type_map)\n",
    "\n",
    "    return tbl\n",
    "\n",
    "\n",
    "train_tbl = preprocess(train_tbl)\n",
    "valid_tbl = preprocess(valid_tbl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9097b819",
   "metadata": {},
   "source": [
    "For the language feature, we first of all assign each languages that appear in the train dataset with ids starting from 1 and 0 is reserved for unknown languages. Note that for the validation dataset, we use the StringIndex generated by the train dataset to encode its language features. It is a common case that the new dataset may have unseen categories and we use the reserved index 0 to fill null values and encode them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c0a22cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The number of languages: 66\n"
     ]
    }
   ],
   "source": [
    "if \"language\" in cat_cols:\n",
    "    train_tbl, language_idx = train_tbl.category_encode(\"language\")\n",
    "    valid_tbl = valid_tbl.encode_string(\"language\", language_idx)\n",
    "    valid_tbl = valid_tbl.fillna(0, \"language\")\n",
    "\n",
    "    print(\"The number of languages: {}\".format(language_idx.size()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61238107",
   "metadata": {},
   "source": [
    "Besides the preprocessing steps above, we can also generate new features from the existing features. For example,\n",
    "\n",
    "- Generating cross columns given multiple categorical columns is a common technique for Wide & Deep learning to memorize the co-appearance of joint features.\n",
    "- We also count the number of hashtags, present domains and present links and treat them as continuous features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0ccc713e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_features(tbl):\n",
    "    cross_cols = [['present_media', 'language']]\n",
    "    cross_dims = [600]\n",
    "    tbl = tbl.cross_columns(cross_cols, cross_dims)  # The resulting cross column will have name \"present_media_language\"\n",
    "\n",
    "    count_func = lambda x: str(x).count('\\t') + 1 if x else 0\n",
    "    tbl = tbl.apply(\"hashtags\", \"len_hashtags\", count_func, \"int\") \\\n",
    "        .apply(\"present_domains\", \"len_domains\", count_func, \"int\") \\\n",
    "        .apply(\"present_links\", \"len_links\", count_func, \"int\")\n",
    "    return tbl\n",
    "\n",
    "\n",
    "train_tbl = generate_features(train_tbl)\n",
    "valid_tbl = generate_features(valid_tbl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "019762d6",
   "metadata": {},
   "source": [
    "For the number of hashtags, present domains and present links, we use min max scaling to rescale them to the range [0, 1] as described above. Similarly, we use the min and max stats of the train dataset to transform the validation dataset as it is a common case that the new dataset may have out-of-range values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e2199235",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "len_cols = ['len_hashtags',\n",
    "            'len_domains',\n",
    "            'len_links']\n",
    "\n",
    "train_tbl, min_max_dict = train_tbl.min_max_scale(len_cols)\n",
    "valid_tbl = valid_tbl.transform_min_max_scale(len_cols, min_max_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98f4bcaf",
   "metadata": {},
   "source": [
    "Besides the preprocessing and feature engineering operations described above, BigDL provides a lot more built-in operations including: target encoding, count encoding, difference lag, negative sampling, etc. See [here](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Friesian/feature.html#bigdl.friesian.feature.table.FeatureTable) for more details and API usage."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "256badad",
   "metadata": {},
   "source": [
    "Finally, we are to process the label for supervised learning. The four timestamp features indicate whether the engaging user interact with the tweet and they will be jointly used to produce the label.\n",
    "\n",
    "If one of the timestamps is not null, namely the engaging user has at least one interaction with the tweet, the record will be treated as a positive sample (i.e. having label 1)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7d55982c",
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp_cols = [\n",
    "    'reply_timestamp',\n",
    "    'retweet_timestamp',\n",
    "    'retweet_with_comment_timestamp',\n",
    "    'like_timestamp'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "74693e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_label(tbl):\n",
    "    tbl = tbl.cast(timestamp_cols, \"int\")\n",
    "    tbl = tbl.fillna(0, timestamp_cols)\n",
    "    gen_label = lambda x: 1 if max(x) > 0 else 0\n",
    "    tbl = tbl.apply(in_col=timestamp_cols, out_col=\"label\", func=gen_label, dtype=\"int\")\n",
    "    return tbl\n",
    "\n",
    "\n",
    "train_tbl = transform_label(train_tbl)\n",
    "valid_tbl = transform_label(valid_tbl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c508f59d",
   "metadata": {},
   "source": [
    "Now we are all settled! We have finished all the preprocessing steps for this dataset.\n",
    "\n",
    "Let's take a look at the preprocessed dataset and finally save the preprocessed data to be used for Wide & Deep training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8c68bcbc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------------+------------------------+-------------+----------+--------+\n",
      "|engaged_with_user_is_verified|enaging_user_is_verified|present_media|tweet_type|language|\n",
      "+-----------------------------+------------------------+-------------+----------+--------+\n",
      "|                            0|                       0|            0|         1|      40|\n",
      "|                            0|                       0|            5|         1|      43|\n",
      "|                            0|                       0|            7|         2|      43|\n",
      "|                            1|                       0|            0|         1|      43|\n",
      "|                            0|                       0|            0|         1|      43|\n",
      "+-----------------------------+------------------------+-------------+----------+--------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_tbl.select(bool_cols + cat_cols).show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e4187d8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------------------+---------------------------------+---------------------------+----------------------------+\n",
      "|engaged_with_user_follower_count|engaged_with_user_following_count|enaging_user_follower_count|enaging_user_following_count|\n",
      "+--------------------------------+---------------------------------+---------------------------+----------------------------+\n",
      "|                               2|                                2|                          1|                           1|\n",
      "|                               2|                                2|                          2|                           2|\n",
      "|                               4|                                4|                          2|                           2|\n",
      "|                               5|                                4|                          2|                           2|\n",
      "|                               2|                                2|                          3|                           2|\n",
      "+--------------------------------+---------------------------------+---------------------------+----------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "[Stage 30:>                                                         (0 + 1) / 1]\r\n",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "train_tbl.select(count_cols).show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "46c7d5f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "[Stage 31:>                                                         (0 + 1) / 1]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----------+---------+----------------------+-----+\n",
      "|len_hashtags|len_domains|len_links|present_media_language|label|\n",
      "+------------+-----------+---------+----------------------+-----+\n",
      "|         0.0|        0.0|      0.0|                   123|    0|\n",
      "|         0.0|        0.0|      0.0|                   281|    1|\n",
      "|         0.0|        0.0|      0.0|                   463|    0|\n",
      "|         0.0|        0.0|      0.0|                   126|    1|\n",
      "|         0.0|        0.0|      0.0|                   126|    0|\n",
      "+------------+-----------+---------+----------------------+-----+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "train_tbl.select(len_cols + [\"present_media_language\", \"label\"]).show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "448e5ab5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "train_tbl.write_parquet(\"/path/to/preprocessed/train/data\")\n",
    "valid_tbl.write_parquet(\"/path/to/preprocessed/valid/data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f7568c7",
   "metadata": {},
   "source": [
    "Also save the StringIndex for each categorical feature to parquet files. The mapping from string to its corresponding id would be used in the later stages (e.g. to encode new data)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3b352183",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stopping orca context\n"
     ]
    }
   ],
   "source": [
    "model_path = \"/path/to/models/data\"\n",
    "\n",
    "if \"language\" in cat_cols:\n",
    "    language_idx.write_parquet(model_path)  # Saved to model_path/language.parquet\n",
    "\n",
    "if \"present_media\" in cat_cols:\n",
    "    media_idx = StringIndex.from_dict(media_map, \"present_media\")\n",
    "    media_idx.write_parquet(model_path)  # Saved to model_path/present_media.parquet\n",
    "\n",
    "if \"tweet_type\" in cat_cols:\n",
    "    type_idx = StringIndex.from_dict(type_map, \"tweet_type\")\n",
    "    type_idx.write_parquet(model_path)  # Saved to model_path/tweet_type.parquet\n",
    "\n",
    "    \n",
    "stop_orca_context()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('friesian')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "6c9c6c36520220f3300aaef54e998202ae22aa0de07e5d4c5d990becb2915fb2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
